{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Lesson 5: Fine-Tuning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p style=\"background-color:#fff6e4; padding:15px; border-width:3px; border-color:#f5ecda; border-style:solid; border-radius:6px\"> ⏳ <b>Note <code>(Kernel Starting)</code>:</b> This notebook takes about 30 seconds to be ready to use. You may start and watch the video while you wait.</p>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* In this classroom, the libraries have already been installed for you.\n",
    "* If you would like to run this code on your own machine, you need to install the following:\n",
    "    ```\n",
    "    !pip install -q accelerate torch diffusers transformers comet_ml\n",
    "    ```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set up Comet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Here you will fine-tune the model using the [HuggingFace DreamBooth](https://huggingface.co/docs/diffusers/en/training/dreambooth) training method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "import comet_ml"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "comet_ml.init(anonymous=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import and prepare the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 114
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    model_name = 'stabilityai/stable-diffusion-xl-base-1.0'\n",
    "else:\n",
    "    model_name = './models/runwayml/stable-diffusion-v1-5'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 233
   },
   "outputs": [],
   "source": [
    "# Define hyperparameters\n",
    "\n",
    "hyperparameters = {\n",
    "    \"instance_prompt\": \"a photo of a [V] man\",\n",
    "    \"class_prompt\": \"a photo of a man\",\n",
    "    \"seed\": 4329,\n",
    "    \"pretrained_model_name_or_path\": model_name,\n",
    "    \"resolution\": 1024 if torch.cuda.is_available() else 512,\n",
    "    \"num_inference_steps\": 50,\n",
    "    \"guidance_scale\": 5.0,\n",
    "    \"num_class_images\": 200,\n",
    "    \"prior_loss_weight\": 1.0\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Create a new **Comet** experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "experiment = comet_ml.Experiment()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "from utils import DreamBoothTrainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "trainer = DreamBoothTrainer(hyperparameters)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Note\n",
    "- The code that generates images requires a GPU to run.\n",
    "- The code is left here in markdown, but if you have access to GPUs outside of the classroom, you can run it there.\n",
    "- In the classroom, you'll still be able to follow along by retrieving the generated images from the experiment tracking tool (Comet)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```Python\n",
    "# Generate the class images (requires a GPU)\n",
    "trainer.generate_class_images()\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```Python\n",
    "# To see the source code of generate_class_images\n",
    "??trainer.generate_class_images\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Get class images (using artifacts)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "import shutil"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "# Get images\n",
    "class_artifact = experiment.get_artifact('ckaiser/class-images-15')\n",
    "class_artifact.download('./')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "shutil.unpack_archive('./class.zip', './class')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">Note: the images referenced in this notebook have already been uploaded to the Jupyter directory, in this classroom, for your convenience. For further details, please refer to the **Appendix** section located at the end of the lessons."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Print some images\n",
    "trainer.display_images(\"class\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Get the instance dataset (images of Andrew)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "andrew_artifact = experiment.get_artifact('ckaiser/andrew-dataset')\n",
    "andrew_artifact.download('./')\n",
    "\n",
    "shutil.unpack_archive('./andrew-dataset.zip', './instance')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Print some images\n",
    "trainer.display_images(\"instance\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialize the model\n",
    "- It will take some time (several minutes) to initialize the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "tokenizer, text_encoder, vae, unet = trainer.initialize_models()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> Note: see the video lesson for the LoRA explanation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 131
   },
   "outputs": [],
   "source": [
    "# Load the noise scheduler, which adds noise to the latents during training\n",
    "from diffusers import DDPMScheduler\n",
    "\n",
    "noise_scheduler = DDPMScheduler.from_pretrained(\n",
    "    trainer.hyperparameters.pretrained_model_name_or_path,\n",
    "    subfolder=\"scheduler\"\n",
    ")"
   ]
  },
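  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "During training, the scheduler's `add_noise` method mixes each clean latent with Gaussian noise according to the sampled timestep. The sketch below illustrates the standard DDPM forward-noising rule, `noisy = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise`, in plain Python on scalar values. The linear beta schedule, its endpoints, and the helper names are assumptions chosen for illustration; the scheduler loaded above reads its own schedule from the model's `scheduler` config.\n",
    "\n",
    "```Python\n",
    "import math\n",
    "\n",
    "def alphas_cumprod(num_steps=1000, beta_start=1e-4, beta_end=0.02):\n",
    "    # Hypothetical linear beta schedule; returns alpha_bar_t for every timestep t.\n",
    "    products, running = [], 1.0\n",
    "    for t in range(num_steps):\n",
    "        beta = beta_start + (beta_end - beta_start) * t / (num_steps - 1)\n",
    "        running *= 1.0 - beta\n",
    "        products.append(running)\n",
    "    return products\n",
    "\n",
    "def add_noise(x0, noise, t, alpha_bars):\n",
    "    # Forward diffusion for a single scalar value at timestep t.\n",
    "    signal = math.sqrt(alpha_bars[t])\n",
    "    sigma = math.sqrt(1.0 - alpha_bars[t])\n",
    "    return signal * x0 + sigma * noise\n",
    "\n",
    "alpha_bars = alphas_cumprod()\n",
    "early = add_noise(1.0, 0.0, 0, alpha_bars)    # first timestep: mostly signal\n",
    "late = add_noise(1.0, 0.0, 999, alpha_bars)   # last timestep: almost no signal left\n",
    "print(early, late)\n",
    "```\n",
    "\n",
    "Sampling a random timestep per image, as the training loop does, exposes the model to every noise level between these two extremes."
   ]
  },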
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "unet = trainer.initialize_lora(unet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "optimizer, params_to_optimize = trainer.initialize_optimizer(unet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "# Initialize the datasets\n",
    "train_dataset, train_dataloader = trainer.prepare_dataset(tokenizer, text_encoder)\n",
    "lr_scheduler = trainer.initialize_scheduler(train_dataloader, optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "unet, optimizer, train_dataloader, lr_scheduler = trainer.accelerator.prepare(\n",
    "    unet, optimizer, train_dataloader, lr_scheduler)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "total_batch_size = \\\n",
    "    trainer.hyperparameters.train_batch_size * \\\n",
    "    trainer.hyperparameters.gradient_accumulation_steps"
   ]
  },
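  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The product above is the effective batch size: the number of examples that contribute to each optimizer update once gradient accumulation is taken into account. As a small illustration with hypothetical numbers (the real values live in `trainer.hyperparameters`, and the helper names below are made up for this sketch):\n",
    "\n",
    "```Python\n",
    "import math\n",
    "\n",
    "def effective_batch_size(train_batch_size, gradient_accumulation_steps):\n",
    "    # Examples that contribute to one optimizer update.\n",
    "    return train_batch_size * gradient_accumulation_steps\n",
    "\n",
    "def updates_per_epoch(num_batches, gradient_accumulation_steps):\n",
    "    # Optimizer updates per pass over the dataloader (a final partial group still counts).\n",
    "    return math.ceil(num_batches / gradient_accumulation_steps)\n",
    "\n",
    "total = effective_batch_size(train_batch_size=2, gradient_accumulation_steps=4)\n",
    "updates = updates_per_epoch(num_batches=10, gradient_accumulation_steps=4)\n",
    "print(total, updates)\n",
    "```\n",
    "\n",
    "Accumulating gradients trades more forward and backward passes per update for lower peak memory, which can help when a larger batch does not fit on the GPU."
   ]
  },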
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Note\n",
    "- Starting from this point, the code demonstrated by the instructor will not execute in this notebook due to computational resource constraints. However, we provide the code here for you to run if you have access to a GPU or similar resources.\n",
    "- Thank you for your understanding as we work to provide free and accessible courses."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```Python\n",
    "from tqdm import tqdm\n",
    "\n",
    "global_step = 0\n",
    "epoch = 0\n",
    "\n",
    "progress_bar = tqdm(\n",
    "    range(0, trainer.hyperparameters.max_train_steps),\n",
    "    desc=\"Steps\"\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```Python\n",
    "import torch.nn.functional as F\n",
    "\n",
    "for epoch in range(0, trainer.hyperparameters.num_train_epochs):\n",
    "    unet.train()\n",
    "\n",
    "    for step, batch in enumerate(train_dataloader):\n",
    "        with trainer.accelerator.accumulate(unet):\n",
    "            # Encode the images into the VAE latent space\n",
    "            pixel_values = batch[\"pixel_values\"].to(dtype=vae.dtype)\n",
    "            model_input = vae.encode(pixel_values).latent_dist.sample()\n",
    "            model_input = model_input * vae.config.scaling_factor\n",
    "\n",
    "            # Sample the noise and a random timestep for each image\n",
    "            noise = torch.randn_like(model_input)\n",
    "            bsz, channels, height, width = model_input.shape\n",
    "\n",
    "            timesteps = torch.randint(\n",
    "                0,\n",
    "                noise_scheduler.config.num_train_timesteps,\n",
    "                (bsz,),\n",
    "                device=model_input.device\n",
    "            )\n",
    "\n",
    "            timesteps = timesteps.long()\n",
    "            noisy_model_input = noise_scheduler.add_noise(\n",
    "                model_input,\n",
    "                noise,\n",
    "                timesteps\n",
    "            )\n",
    "\n",
    "            # Text conditioning (assumes prepare_dataset pre-computed the text embeddings)\n",
    "            encoder_hidden_states = batch[\"input_ids\"]\n",
    "\n",
    "            # Predict the noise that was added\n",
    "            model_pred = unet(\n",
    "                noisy_model_input,\n",
    "                timesteps,\n",
    "                encoder_hidden_states,\n",
    "                return_dict=False,\n",
    "            )[0]\n",
    "\n",
    "            target = noise\n",
    "\n",
    "            # Split the batch into its instance half and its class (prior) half\n",
    "            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)\n",
    "            target, target_prior = torch.chunk(target, 2, dim=0)\n",
    "\n",
    "            instance_loss = \\\n",
    "                F.mse_loss(\n",
    "                    model_pred.float(),\n",
    "                    target.float(),\n",
    "                    reduction=\"mean\"\n",
    "                )\n",
    "\n",
    "            prior_loss = \\\n",
    "                F.mse_loss(\n",
    "                    model_pred_prior.float(),\n",
    "                    target_prior.float(),\n",
    "                    reduction=\"mean\"\n",
    "                )\n",
    "\n",
    "            loss = \\\n",
    "                instance_loss + \\\n",
    "                trainer.hyperparameters.prior_loss_weight * \\\n",
    "                prior_loss\n",
    "\n",
    "            trainer.accelerator.backward(loss)\n",
    "            optimizer.step()\n",
    "            lr_scheduler.step()\n",
    "            optimizer.zero_grad()\n",
    "            global_step += 1\n",
    "\n",
    "        # Log the metrics for this step to Comet\n",
    "        loss_metrics = {\n",
    "            \"loss\": loss.detach().item(),\n",
    "            \"prior_loss\": prior_loss.detach().item(),\n",
    "            \"lr\": lr_scheduler.get_last_lr()[0],\n",
    "        }\n",
    "\n",
    "        experiment.log_metrics(loss_metrics, step=global_step)\n",
    "\n",
    "        progress_bar.set_postfix(**loss_metrics)\n",
    "        progress_bar.update(1)\n",
    "\n",
    "        if global_step >= trainer.hyperparameters.max_train_steps:\n",
    "            break\n",
    "\n",
    "    trainer.save_lora_weights(unet)\n",
    "experiment.add_tag(\"dreambooth-training\")\n",
    "experiment.log_parameters(trainer.hyperparameters)\n",
    "trainer.accelerator.end_training()\n",
    "```"
   ]
  },
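  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loss in the loop above has two parts: an instance term that teaches the model the `[V]` subject, and a prior term, computed on the class images, that discourages the model from forgetting what a generic man looks like. The following minimal sketch reproduces just that arithmetic on dummy tensors. It assumes, as the `torch.chunk` calls imply, that each batch holds the instance examples in its first half and the class examples in its second half; the function name and tensor shapes are illustrative only.\n",
    "\n",
    "```Python\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def prior_preservation_loss(model_pred, target, prior_loss_weight=1.0):\n",
    "    # Split predictions and targets into instance and class (prior) halves.\n",
    "    pred_instance, pred_prior = torch.chunk(model_pred, 2, dim=0)\n",
    "    target_instance, target_prior = torch.chunk(target, 2, dim=0)\n",
    "    instance_loss = F.mse_loss(pred_instance.float(), target_instance.float())\n",
    "    prior_loss = F.mse_loss(pred_prior.float(), target_prior.float())\n",
    "    return instance_loss + prior_loss_weight * prior_loss\n",
    "\n",
    "# Dummy batch: 2 instance examples followed by 2 class examples.\n",
    "model_pred = torch.zeros(4, 1, 2, 2)\n",
    "target = torch.cat([torch.ones(2, 1, 2, 2), 2 * torch.ones(2, 1, 2, 2)])\n",
    "loss = prior_preservation_loss(model_pred, target, prior_loss_weight=0.5)\n",
    "print(loss.item())\n",
    "```\n",
    "\n",
    "Raising `prior_loss_weight` pushes the model to stay closer to its original notion of the class, while lowering it lets the instance images dominate."
   ]
  },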
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Retrieve the training results\n",
    "- You can get the training results using the experiment tracking tool, Comet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "training_experiment = \\\n",
    "    comet_ml.APIExperiment(\n",
    "        previous_experiment=\"d92519b1f657497e8569a2c8e989b457\"\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "# See the experiment\n",
    "training_experiment.display()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Prompts to generate images of Andrew."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 318
   },
   "outputs": [],
   "source": [
    "prompts = [\n",
    "    \"a photo of a [V] man playing basketball\",\n",
    "    \"a photo of a [V] man riding a horse\",\n",
    "    \"a photo of a [V] man at the summit of a mountain\",\n",
    "    \"a photo of a [V] man driving a convertible\",\n",
    "    \"a photo of a [V] man riding a skateboard on a huge halfpipe\",\n",
    "    \"a mural of a [V] man, painted by graffiti artists\"\n",
    "]\n",
    "\n",
    "validation_prompts = [\n",
    "    \"a photo of a man playing basketball\",\n",
    "    \"a photo of a man riding a horse\",\n",
    "    \"a photo of a man at the summit of a mountain\",\n",
    "    \"a photo of a man driving a convertible\",\n",
    "    \"a photo of a man riding a skateboard on a huge halfpipe\",\n",
    "    \"a mural of a man, painted by graffiti artists\"\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Note\n",
    "- The folowing code requires GPUs."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```Python\n",
    "from diffusers import DiffusionPipeline\n",
    "\n",
    "pipeline = DiffusionPipeline.from_pretrained(\"runwayml/stable-diffusion-v1-5\")\n",
    "pipeline.load_lora_weights(\"./andrew-model\")\n",
    "\n",
    "# Generate and log one image per prompt that uses the fine-tuned [V] token\n",
    "for prompt in prompts:\n",
    "    with torch.no_grad():\n",
    "        images = pipeline(\n",
    "            prompt=prompt,\n",
    "        ).images\n",
    "\n",
    "    experiment.log_image(images[0], metadata={\n",
    "        \"prompt\": prompt,\n",
    "        \"model\": hyperparameters[\"pretrained_model_name_or_path\"],\n",
    "    })\n",
    "\n",
    "# Repeat with the same prompts without the [V] token, for comparison\n",
    "for prompt in validation_prompts:\n",
    "    with torch.no_grad():\n",
    "        images = pipeline(\n",
    "            prompt=prompt,\n",
    "        ).images\n",
    "\n",
    "    experiment.log_image(images[0], metadata={\n",
    "        \"prompt\": prompt,\n",
    "        \"model\": hyperparameters[\"pretrained_model_name_or_path\"],\n",
    "    })\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Retrieve the image generation results\n",
    "- You can view the results of image generation regardless of whether you have access to GPUs, using the experiment tracking tool."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "inference_experiment = comet_ml.APIExperiment(\n",
    "        previous_experiment=\"0eb292126ab5476ab0c863061a400bdc\"\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "# See the experiment\n",
    "inference_experiment.display(tab=\"images\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Additional Resources\n",
    "* For more on how to use [Comet](https://www.comet.com/site/?utm_source=dlai&utm_medium=course&utm_campaign=prompt_engineering_for_vision_models&utm_content=dlai_L5) for experiment tracking, check out this [Quickstart Guide](https://colab.research.google.com/drive/1jj9BgsFApkqnpPMLCHSDH-5MoL_bjvYq?usp=sharing) and the [Comet Docs](https://www.comet.com/docs/v2/?utm_source=dlai&utm_medium=course&utm_campaign=prompt_engineering_for_vision_models&utm_content=dlai_L5).\n",
    "* This course was based off a set of two blog articles from Comet. Explore them here for more on how to use newer versions of Stable Diffusion in this pipeline, additional tricks to improve your inpainting results, and a breakdown of the pipeline architecture:\n",
    "  * [SAM + Stable Diffusion for Text-to-Image Inpainting](https://www.comet.com/site/blog/sam-stable-diffusion-for-text-to-image-inpainting/?utm_source=dlai&utm_medium=course&utm_campaign=prompt_engineering_for_vision_models&utm_content=dlai_L5)\n",
    "  * [Image Inpainting for SDXL 1.0 Base Model + Refiner](https://www.comet.com/site/blog/image-inpainting-for-sdxl-1-0-base-refiner/?utm_source=dlai&utm_medium=course&utm_campaign=prompt_engineering_for_vision_models&utm_content=dlai_L5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Did you like this course?\n",
    "\n",
    "- If you liked this course, please consider giving it a rating and sharing what you liked. 💕\n",
    "- If you did not like this course, please share what you think could have made it better. 🙏\n",
    "\n",
    "#### A note about the \"Course Review\" page\n",
    "The rating options range from 0 to 10 and are used to calculate the \"Net Promoter Score\":\n",
    "- A score of 9 or 10 means you like the course. 💫 💕\n",
    "- A score of 7 or 8 means you feel neutral about the course (neither like nor dislike). 🙄\n",
    "- A score from 0 to 6 means that you do not like the course. 😭"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
